{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c971744c-c456-41e4-82d2-559079585c16",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "\n",
    "from transformers import AutoConfig, AutoModelForMaskedLM\n",
    "from transformers import Trainer, TrainingArguments\n",
    "from transformers import DataCollatorForLanguageModeling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "46208b12",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "# Load the training data for the FinancialClassification task.\n",
    "# A context manager is used so the file handle is closed right after parsing.\n",
    "with open('../data/train_data.json', 'r', encoding='utf-8') as f:\n",
    "    full_data = json.load(f)\n",
    "\n",
    "task_data = full_data['cls']['FinancialClassification']\n",
    "data_list = task_data['data_list']\n",
    "# Assign each label name a contiguous integer id, following the order in the file\n",
    "label_map = {label: idx for idx, label in enumerate(task_data['label_mappings'])}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "cc804bdf-11c5-4d62-b4f8-b43871c59115",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义模型\n",
    "model_checkpoint = \"hfl/chinese-roberta-wwm-ext\"\n",
    "tokenizer_checkpoint = \"hfl/chinese-roberta-wwm-ext\"\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "3ccbef85-aae9-44d6-be5e-707db6162007",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "loading configuration file https://huggingface.co/hfl/chinese-roberta-wwm-ext/resolve/main/config.json from cache at C:\\Users\\Administrator/.cache\\huggingface\\transformers\\b64aa51c20341fe5461d0663677dd1527cc8fb71c482d06ee75406857d0ed53a.ec9dcc5b0fb354ff7e07c612e71b31e4e41d5ffafe36e4ac9b37959db8873460\n",
      "Model config BertConfig {\n",
      "  \"_name_or_path\": \"hfl/chinese-roberta-wwm-ext\",\n",
      "  \"architectures\": [\n",
      "    \"BertForMaskedLM\"\n",
      "  ],\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"bos_token_id\": 0,\n",
      "  \"classifier_dropout\": null,\n",
      "  \"directionality\": \"bidi\",\n",
      "  \"eos_token_id\": 2,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"model_type\": \"bert\",\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"output_past\": true,\n",
      "  \"pad_token_id\": 0,\n",
      "  \"pooler_fc_size\": 768,\n",
      "  \"pooler_num_attention_heads\": 12,\n",
      "  \"pooler_num_fc_layers\": 3,\n",
      "  \"pooler_size_per_head\": 128,\n",
      "  \"pooler_type\": \"first_token_transform\",\n",
      "  \"position_embedding_type\": \"absolute\",\n",
      "  \"transformers_version\": \"4.18.0\",\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"use_cache\": true,\n",
      "  \"vocab_size\": 21128\n",
      "}\n",
      "\n",
      "loading file https://huggingface.co/hfl/chinese-roberta-wwm-ext/resolve/main/vocab.txt from cache at C:\\Users\\Administrator/.cache\\huggingface\\transformers\\92a56e79ec6564fd501527ed88ca336637eb4bfeb28d10580c3bbdfb7889a032.accd894ff58c6ff7bd4f3072890776c14f4ea34fcc08e79cd88c2d157756dceb\n",
      "loading file https://huggingface.co/hfl/chinese-roberta-wwm-ext/resolve/main/tokenizer.json from cache at C:\\Users\\Administrator/.cache\\huggingface\\transformers\\e6278a884ec926a36ddf8d3cc3c598a65dd410c8c01be870468c4f2f71bee0d7.660ed5c7513bf13d4607410502a84e0de517eb889ff8c401068a1688868e1ccb\n",
      "loading file https://huggingface.co/hfl/chinese-roberta-wwm-ext/resolve/main/added_tokens.json from cache at C:\\Users\\Administrator/.cache\\huggingface\\transformers\\87c7eedd995b4bae2c34df3baf2cbd5df5496bed675126427849c72e590f5574.5cc6e825eb228a7a5cfd27cb4d7151e97a79fb962b31aaf1813aa102e746584b\n",
      "loading file https://huggingface.co/hfl/chinese-roberta-wwm-ext/resolve/main/special_tokens_map.json from cache at C:\\Users\\Administrator/.cache\\huggingface\\transformers\\d521373fc7ac35f63d56cf303de74a202403dcf1aaa792cd01f653694be59563.dd8bd9bfd3664b530ea4e645105f557769387b3da9f79bdb55ed556bdd80611d\n",
      "loading file https://huggingface.co/hfl/chinese-roberta-wwm-ext/resolve/main/tokenizer_config.json from cache at C:\\Users\\Administrator/.cache\\huggingface\\transformers\\5dedf24c46ec573f8a27ddfa6e737869ec8df462a9a7682b35ded34301c5bdc8.d23f50bbddc3fb34db5a76d47fa9bdd5d75bf4201ad2d49abbcca25629b3e562\n",
      "loading configuration file https://huggingface.co/hfl/chinese-roberta-wwm-ext/resolve/main/config.json from cache at C:\\Users\\Administrator/.cache\\huggingface\\transformers\\b64aa51c20341fe5461d0663677dd1527cc8fb71c482d06ee75406857d0ed53a.ec9dcc5b0fb354ff7e07c612e71b31e4e41d5ffafe36e4ac9b37959db8873460\n",
      "Model config BertConfig {\n",
      "  \"_name_or_path\": \"hfl/chinese-roberta-wwm-ext\",\n",
      "  \"architectures\": [\n",
      "    \"BertForMaskedLM\"\n",
      "  ],\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"bos_token_id\": 0,\n",
      "  \"classifier_dropout\": null,\n",
      "  \"directionality\": \"bidi\",\n",
      "  \"eos_token_id\": 2,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"model_type\": \"bert\",\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"output_past\": true,\n",
      "  \"pad_token_id\": 0,\n",
      "  \"pooler_fc_size\": 768,\n",
      "  \"pooler_num_attention_heads\": 12,\n",
      "  \"pooler_num_fc_layers\": 3,\n",
      "  \"pooler_size_per_head\": 128,\n",
      "  \"pooler_type\": \"first_token_transform\",\n",
      "  \"position_embedding_type\": \"absolute\",\n",
      "  \"transformers_version\": \"4.18.0\",\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"use_cache\": true,\n",
      "  \"vocab_size\": 21128\n",
      "}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "class PreTrainDataset(Dataset):\n",
    "    def __init__(self, data_list, tokenizer, max_seq_len, label_map):\n",
    "        super(PreTrainDataset, self).__init__()\n",
    "        self.data_list = data_list\n",
    "        self.len = len(data_list)\n",
    "        self.tokenizer = tokenizer\n",
    "        self.max_seq_len = max_seq_len\n",
    "        self.label_map = label_map\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        example = self.data_list[index]\n",
    "        text = example['text_a']\n",
    "        data = self.tokenizer.encode_plus(text, return_token_type_ids=True, return_attention_mask=True, padding = 'max_length', max_length=self.max_seq_len)\n",
    "        \n",
    "        result = {'input_ids':  torch.tensor(data['input_ids'], dtype=torch.long), \n",
    "                'token_type_ids':  torch.tensor(data['token_type_ids'], dtype=torch.long), \n",
    "                'attention_mask':  torch.tensor(data['attention_mask'], dtype=torch.long)}\n",
    "        return result\n",
    "    \n",
    "    def __len__(self):\n",
    "        return self.len\n",
    "\n",
    "# Tokenizer matching the chosen checkpoint\n",
    "tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)\n",
    "# Dataset over the loaded examples and a DataLoader wrapping it\n",
    "train_dataset = PreTrainDataset(\n",
    "    data_list=data_list,\n",
    "    tokenizer=tokenizer,\n",
    "    max_seq_len=32,\n",
    "    label_map=label_map,\n",
    ")\n",
    "train_dataloader = DataLoader(dataset=train_dataset, batch_size=2, num_workers=0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "f20be05f-f05e-44e7-8a78-aa0c6ee1377d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "loading configuration file https://huggingface.co/hfl/chinese-roberta-wwm-ext/resolve/main/config.json from cache at C:\\Users\\Administrator/.cache\\huggingface\\transformers\\b64aa51c20341fe5461d0663677dd1527cc8fb71c482d06ee75406857d0ed53a.ec9dcc5b0fb354ff7e07c612e71b31e4e41d5ffafe36e4ac9b37959db8873460\n",
      "Model config BertConfig {\n",
      "  \"_name_or_path\": \"hfl/chinese-roberta-wwm-ext\",\n",
      "  \"architectures\": [\n",
      "    \"BertForMaskedLM\"\n",
      "  ],\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"bos_token_id\": 0,\n",
      "  \"classifier_dropout\": null,\n",
      "  \"directionality\": \"bidi\",\n",
      "  \"eos_token_id\": 2,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"model_type\": \"bert\",\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"output_past\": true,\n",
      "  \"pad_token_id\": 0,\n",
      "  \"pooler_fc_size\": 768,\n",
      "  \"pooler_num_attention_heads\": 12,\n",
      "  \"pooler_num_fc_layers\": 3,\n",
      "  \"pooler_size_per_head\": 128,\n",
      "  \"pooler_type\": \"first_token_transform\",\n",
      "  \"position_embedding_type\": \"absolute\",\n",
      "  \"transformers_version\": \"4.18.0\",\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"use_cache\": true,\n",
      "  \"vocab_size\": 21128\n",
      "}\n",
      "\n",
      "using `logging_steps` to initialize `eval_steps` to 10000\n",
      "PyTorch: setting up devices\n",
      "The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).\n"
     ]
    }
   ],
   "source": [
    "# Load the model from the pretrained checkpoint so that MLM training continues\n",
    "# pre-training; AutoModelForMaskedLM.from_config() would only build the architecture\n",
    "# with randomly initialised weights.\n",
    "config = AutoConfig.from_pretrained(model_checkpoint)\n",
    "model = AutoModelForMaskedLM.from_pretrained(model_checkpoint, config=config)\n",
    "model.to(device)\n",
    "# Training arguments\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='pretrain_bert',\n",
    "    overwrite_output_dir=True,\n",
    "    do_train=True,\n",
    "    # No eval_dataset is passed to the Trainer below, so evaluation is left disabled\n",
    "    # (the evaluation strategy defaults to 'no'); requesting step-wise evaluation\n",
    "    # without an eval dataset fails as soon as an evaluation step is reached.\n",
    "    do_eval=False,\n",
    "    per_device_train_batch_size=32,\n",
    "    per_device_eval_batch_size=100,\n",
    "    logging_steps=10000,\n",
    "    prediction_loss_only=True,\n",
    "    learning_rate=1e-5,\n",
    "    weight_decay=0.01,\n",
    "    adam_epsilon=1e-8,\n",
    "    max_grad_norm=1.0,\n",
    "    num_train_epochs=10,\n",
    "    # No intermediate checkpoints; the final model is written by trainer.save_model()\n",
    "    save_strategy='no',\n",
    "    push_to_hub=False,\n",
    ")\n",
    "# Collator that applies random masking (15% of tokens) for the MLM objective\n",
    "data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)\n",
    "# Trainer tying together model, arguments, data and collator\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    "    data_collator=data_collator,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "3e0330f6-b404-4422-be37-94c78a373744",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "***** Running training *****\n",
      "  Num examples = 20\n",
      "  Num Epochs = 10\n",
      "  Instantaneous batch size per device = 32\n",
      "  Total train batch size (w. parallel, distributed & accumulation) = 32\n",
      "  Gradient Accumulation steps = 1\n",
      "  Total optimization steps = 10\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "[enforce fail at ..\\c10\\core\\CPUAllocator.cpp:76] data. DefaultCPUAllocator: not enough memory: you tried to allocate 230496000 bytes.",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31mRuntimeError\u001B[0m                              Traceback (most recent call last)",
      "\u001B[1;32m<ipython-input-30-2a6fd8ec2e8f>\u001B[0m in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[1;32m----> 1\u001B[1;33m \u001B[0mtrainer\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mtrain\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m      2\u001B[0m \u001B[0mtrainer\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0msave_model\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\transformers\\trainer.py\u001B[0m in \u001B[0;36mtrain\u001B[1;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001B[0m\n\u001B[0;32m   1420\u001B[0m                         \u001B[0mtr_loss_step\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mtraining_step\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mmodel\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0minputs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m   1421\u001B[0m                 \u001B[1;32melse\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m-> 1422\u001B[1;33m                     \u001B[0mtr_loss_step\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mtraining_step\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mmodel\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0minputs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m   1423\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m   1424\u001B[0m                 if (\n",
      "\u001B[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\transformers\\trainer.py\u001B[0m in \u001B[0;36mtraining_step\u001B[1;34m(self, model, inputs)\u001B[0m\n\u001B[0;32m   2009\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m   2010\u001B[0m         \u001B[1;32mwith\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mautocast_smart_context_manager\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m-> 2011\u001B[1;33m             \u001B[0mloss\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mcompute_loss\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mmodel\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0minputs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m   2012\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m   2013\u001B[0m         \u001B[1;32mif\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0margs\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mn_gpu\u001B[0m \u001B[1;33m>\u001B[0m \u001B[1;36m1\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\transformers\\trainer.py\u001B[0m in \u001B[0;36mcompute_loss\u001B[1;34m(self, model, inputs, return_outputs)\u001B[0m\n\u001B[0;32m   2041\u001B[0m         \u001B[1;32melse\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m   2042\u001B[0m             \u001B[0mlabels\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;32mNone\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m-> 2043\u001B[1;33m         \u001B[0moutputs\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mmodel\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m**\u001B[0m\u001B[0minputs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m   2044\u001B[0m         \u001B[1;31m# Save past state if it exists\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m   2045\u001B[0m         \u001B[1;31m# TODO: this needs to be fixed and made cleaner later.\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001B[0m in \u001B[0;36m_call_impl\u001B[1;34m(self, *input, **kwargs)\u001B[0m\n\u001B[0;32m   1100\u001B[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001B[0;32m   1101\u001B[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001B[1;32m-> 1102\u001B[1;33m             \u001B[1;32mreturn\u001B[0m \u001B[0mforward_call\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m*\u001B[0m\u001B[0minput\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m   1103\u001B[0m         \u001B[1;31m# Do not call functions when jit is used\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m   1104\u001B[0m         \u001B[0mfull_backward_hooks\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mnon_full_backward_hooks\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;33m[\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m[\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\transformers\\models\\bert\\modeling_bert.py\u001B[0m in \u001B[0;36mforward\u001B[1;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, labels, output_attentions, output_hidden_states, return_dict)\u001B[0m\n\u001B[0;32m   1341\u001B[0m         \u001B[0mreturn_dict\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mreturn_dict\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0mreturn_dict\u001B[0m \u001B[1;32mis\u001B[0m \u001B[1;32mnot\u001B[0m \u001B[1;32mNone\u001B[0m \u001B[1;32melse\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mconfig\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0muse_return_dict\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m   1342\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m-> 1343\u001B[1;33m         outputs = self.bert(\n\u001B[0m\u001B[0;32m   1344\u001B[0m             \u001B[0minput_ids\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m   1345\u001B[0m             \u001B[0mattention_mask\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mattention_mask\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001B[0m in \u001B[0;36m_call_impl\u001B[1;34m(self, *input, **kwargs)\u001B[0m\n\u001B[0;32m   1100\u001B[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001B[0;32m   1101\u001B[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001B[1;32m-> 1102\u001B[1;33m             \u001B[1;32mreturn\u001B[0m \u001B[0mforward_call\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m*\u001B[0m\u001B[0minput\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m   1103\u001B[0m         \u001B[1;31m# Do not call functions when jit is used\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m   1104\u001B[0m         \u001B[0mfull_backward_hooks\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mnon_full_backward_hooks\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;33m[\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m[\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\transformers\\models\\bert\\modeling_bert.py\u001B[0m in \u001B[0;36mforward\u001B[1;34m(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001B[0m\n\u001B[0;32m    994\u001B[0m             \u001B[0mpast_key_values_length\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mpast_key_values_length\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    995\u001B[0m         )\n\u001B[1;32m--> 996\u001B[1;33m         encoder_outputs = self.encoder(\n\u001B[0m\u001B[0;32m    997\u001B[0m             \u001B[0membedding_output\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    998\u001B[0m             \u001B[0mattention_mask\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mextended_attention_mask\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001B[0m in \u001B[0;36m_call_impl\u001B[1;34m(self, *input, **kwargs)\u001B[0m\n\u001B[0;32m   1100\u001B[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001B[0;32m   1101\u001B[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001B[1;32m-> 1102\u001B[1;33m             \u001B[1;32mreturn\u001B[0m \u001B[0mforward_call\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m*\u001B[0m\u001B[0minput\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m   1103\u001B[0m         \u001B[1;31m# Do not call functions when jit is used\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m   1104\u001B[0m         \u001B[0mfull_backward_hooks\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mnon_full_backward_hooks\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;33m[\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m[\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\transformers\\models\\bert\\modeling_bert.py\u001B[0m in \u001B[0;36mforward\u001B[1;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001B[0m\n\u001B[0;32m    583\u001B[0m                 )\n\u001B[0;32m    584\u001B[0m             \u001B[1;32melse\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 585\u001B[1;33m                 layer_outputs = layer_module(\n\u001B[0m\u001B[0;32m    586\u001B[0m                     \u001B[0mhidden_states\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    587\u001B[0m                     \u001B[0mattention_mask\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001B[0m in \u001B[0;36m_call_impl\u001B[1;34m(self, *input, **kwargs)\u001B[0m\n\u001B[0;32m   1100\u001B[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001B[0;32m   1101\u001B[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001B[1;32m-> 1102\u001B[1;33m             \u001B[1;32mreturn\u001B[0m \u001B[0mforward_call\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m*\u001B[0m\u001B[0minput\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m   1103\u001B[0m         \u001B[1;31m# Do not call functions when jit is used\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m   1104\u001B[0m         \u001B[0mfull_backward_hooks\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mnon_full_backward_hooks\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;33m[\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m[\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\transformers\\models\\bert\\modeling_bert.py\u001B[0m in \u001B[0;36mforward\u001B[1;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001B[0m\n\u001B[0;32m    470\u001B[0m         \u001B[1;31m# decoder uni-directional self-attention cached key/values tuple is at positions 1,2\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    471\u001B[0m         \u001B[0mself_attn_past_key_value\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mpast_key_value\u001B[0m\u001B[1;33m[\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;36m2\u001B[0m\u001B[1;33m]\u001B[0m \u001B[1;32mif\u001B[0m \u001B[0mpast_key_value\u001B[0m \u001B[1;32mis\u001B[0m \u001B[1;32mnot\u001B[0m \u001B[1;32mNone\u001B[0m \u001B[1;32melse\u001B[0m \u001B[1;32mNone\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 472\u001B[1;33m         self_attention_outputs = self.attention(\n\u001B[0m\u001B[0;32m    473\u001B[0m             \u001B[0mhidden_states\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    474\u001B[0m             \u001B[0mattention_mask\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001B[0m in \u001B[0;36m_call_impl\u001B[1;34m(self, *input, **kwargs)\u001B[0m\n\u001B[0;32m   1100\u001B[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001B[0;32m   1101\u001B[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001B[1;32m-> 1102\u001B[1;33m             \u001B[1;32mreturn\u001B[0m \u001B[0mforward_call\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m*\u001B[0m\u001B[0minput\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m   1103\u001B[0m         \u001B[1;31m# Do not call functions when jit is used\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m   1104\u001B[0m         \u001B[0mfull_backward_hooks\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mnon_full_backward_hooks\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;33m[\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m[\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\transformers\\models\\bert\\modeling_bert.py\u001B[0m in \u001B[0;36mforward\u001B[1;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001B[0m\n\u001B[0;32m    400\u001B[0m         \u001B[0moutput_attentions\u001B[0m\u001B[1;33m:\u001B[0m \u001B[0mOptional\u001B[0m\u001B[1;33m[\u001B[0m\u001B[0mbool\u001B[0m\u001B[1;33m]\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;32mFalse\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    401\u001B[0m     ) -> Tuple[torch.Tensor]:\n\u001B[1;32m--> 402\u001B[1;33m         self_outputs = self.self(\n\u001B[0m\u001B[0;32m    403\u001B[0m             \u001B[0mhidden_states\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    404\u001B[0m             \u001B[0mattention_mask\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001B[0m in \u001B[0;36m_call_impl\u001B[1;34m(self, *input, **kwargs)\u001B[0m\n\u001B[0;32m   1100\u001B[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001B[0;32m   1101\u001B[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001B[1;32m-> 1102\u001B[1;33m             \u001B[1;32mreturn\u001B[0m \u001B[0mforward_call\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m*\u001B[0m\u001B[0minput\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m   1103\u001B[0m         \u001B[1;31m# Do not call functions when jit is used\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m   1104\u001B[0m         \u001B[0mfull_backward_hooks\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mnon_full_backward_hooks\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;33m[\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m[\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mD:\\env\\Anaconda3\\lib\\site-packages\\transformers\\models\\bert\\modeling_bert.py\u001B[0m in \u001B[0;36mforward\u001B[1;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001B[0m\n\u001B[0;32m    322\u001B[0m                 \u001B[0mattention_scores\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mattention_scores\u001B[0m \u001B[1;33m+\u001B[0m \u001B[0mrelative_position_scores_query\u001B[0m \u001B[1;33m+\u001B[0m \u001B[0mrelative_position_scores_key\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    323\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 324\u001B[1;33m         \u001B[0mattention_scores\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mattention_scores\u001B[0m \u001B[1;33m/\u001B[0m \u001B[0mmath\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0msqrt\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mattention_head_size\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m    325\u001B[0m         \u001B[1;32mif\u001B[0m \u001B[0mattention_mask\u001B[0m \u001B[1;32mis\u001B[0m \u001B[1;32mnot\u001B[0m \u001B[1;32mNone\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    326\u001B[0m             \u001B[1;31m# Apply the attention mask is (precomputed for all layers in BertModel forward() function)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;31mRuntimeError\u001B[0m: [enforce fail at ..\\c10\\core\\CPUAllocator.cpp:76] data. DefaultCPUAllocator: not enough memory: you tried to allocate 230496000 bytes."
     ]
    }
   ],
   "source": [
    "# Run MLM pre-training, then write the final model to the Trainer's output directory\n",
    "trainer.train()\n",
    "trainer.save_model(trainer.args.output_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f9cc7df-f9da-49e7-a775-c8c0e0280417",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the current working directory by running `pwd` in the system shell\n",
    "! pwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da5bf0bd-1679-46e2-9872-897d573aaf6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the model to the Trainer's output directory and report where it went\n",
    "# (replaces a leftover debug marker that printed only `1`).\n",
    "trainer.save_model()\n",
    "print(f'model saved to {trainer.args.output_dir}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "419c63da-1ef0-4c72-a996-6003e0a8930b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}